{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import random \n",
    "from shutil import copyfile\n",
    "import pydicom as dicom\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- parameters: adjust the values in this cell before running the notebook ----\n",
    "savepath = 'data'  # output root; images are written to <savepath>/train and <savepath>/test\n",
    "\n",
    "# seed numpy and the stdlib RNG so all runs are the same\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "MAXVAL = 255  # Range [0 255]\n",
    "\n",
    "# local copy of the covid-19 dataset from https://github.com/ieee8023/covid-chestxray-dataset\n",
    "imgpath = '../covid-chestxray-dataset/images'\n",
    "csvpath = '../covid-chestxray-dataset/metadata.csv'\n",
    "\n",
    "# local copy of https://www.kaggle.com/c/rsna-pneumonia-detection-challenge\n",
    "kaggle_datapath = 'rsna-pneumonia-detection-challenge'\n",
    "kaggle_csvname = 'stage_2_detailed_class_info.csv'  # rows with class 'Normal' are taken from here\n",
    "kaggle_csvname2 = 'stage_2_train_labels.csv'  # rows with Target == 1 (pneumonia) are taken from here\n",
    "kaggle_imgpath = 'stage_2_train_images'\n",
    "\n",
    "# COVIDx accumulators; the cells below append samples and bump the per-label counters\n",
    "train, test = [], []\n",
    "test_count = dict.fromkeys(('normal', 'pneumonia', 'COVID-19'), 0)\n",
    "train_count = dict.fromkeys(('normal', 'pneumonia', 'COVID-19'), 0)\n",
    "\n",
    "# raw finding / class value -> COVIDx label\n",
    "mapping = {\n",
    "    'COVID-19': 'COVID-19',\n",
    "    'SARS': 'pneumonia',\n",
    "    'MERS': 'pneumonia',\n",
    "    'Streptococcus': 'pneumonia',\n",
    "    'Normal': 'normal',\n",
    "    'Lung Opacity': 'pneumonia',\n",
    "    '1': 'pneumonia',\n",
    "}\n",
    "\n",
    "# train/test split\n",
    "split = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the covid-chestxray-dataset metadata and keep only the rows whose view is PA.\n",
    "# adapted from https://github.com/mlmed/torchxrayvision/blob/master/torchxrayvision/datasets.py#L814\n",
    "csv = pd.read_csv(csvpath, nrows=None)\n",
    "idx_pa = csv['view'].eq('PA')  # boolean mask, True for PA-view rows\n",
    "csv = csv.loc[idx_pa]\n",
    "\n",
    "# sorted list of pathology names (the pneumonia-type findings plus the generic ones)\n",
    "pneumonias = ['COVID-19', 'SARS', 'MERS', 'ARDS', 'Streptococcus']\n",
    "pathologies = sorted(\n",
    "    ['Pneumonia', 'Viral Pneumonia', 'Bacterial Pneumonia', 'No Finding', *pneumonias]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the covid-chestxray-dataset rows whose finding maps to a COVIDx label\n",
    "# (non-COVID19 viral, bacterial and COVID-19 infections).\n",
    "# Each kept row is stored as [patient id, image filename, label], grouped by label.\n",
    "filename_label = {'normal': [], 'pneumonia': [], 'COVID-19': []}\n",
    "count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}\n",
    "for index, row in csv.iterrows():\n",
    "    f = row['finding']\n",
    "    if f not in mapping:\n",
    "        continue  # this finding has no COVIDx label, skip the row\n",
    "    label = mapping[f]\n",
    "    count[label] += 1\n",
    "    filename_label[label].append([int(row['patientid']), row['filename'], label])\n",
    "\n",
    "print('Data distribution from covid-chestxray-dataset:')\n",
    "print(count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add covid-chestxray-dataset into COVIDx dataset\n",
    "# since covid-chestxray-dataset doesn't have test dataset\n",
    "# split into train/test by patientid\n",
    "# for COVIDx:\n",
    "# patient 8 is used as non-COVID19 viral test\n",
    "# patient 31 is used as bacterial test\n",
    "# patients 19, 20, 36, 42, 86 are used as COVID-19 viral test\n",
    "# (fixed ids; a random split would be random.sample(list(arr[:,0]), num_test))\n",
    "covid_test_patients = {\n",
    "    'pneumonia': ['8', '31'],\n",
    "    'COVID-19': ['19', '20', '36', '42', '86'],\n",
    "}\n",
    "\n",
    "# copyfile() does not create missing folders, so make sure both targets exist\n",
    "for subset in ('train', 'test'):\n",
    "    os.makedirs(os.path.join(savepath, subset), exist_ok=True)\n",
    "\n",
    "for key in filename_label.keys():\n",
    "    arr = np.array(filename_label[key])\n",
    "    if arr.size == 0:\n",
    "        continue\n",
    "    # labels without an entry above (e.g. 'normal') contribute no test patients\n",
    "    test_patients = covid_test_patients.get(key, [])\n",
    "    print('Key: ', key)\n",
    "    print('Test patients: ', test_patients)\n",
    "    # go through all the patients; each row is [patient id, filename, label]\n",
    "    for patient in arr:\n",
    "        # np.array() stores the mixed int/str rows as strings, hence the string ids above\n",
    "        if patient[0] in test_patients:\n",
    "            subset, samples, counts = 'test', test, test_count\n",
    "        else:\n",
    "            subset, samples, counts = 'train', train, train_count\n",
    "        copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, subset, patient[1]))\n",
    "        samples.append(patient)\n",
    "        counts[patient[2]] += 1\n",
    "\n",
    "print('test count: ', test_count)\n",
    "print('train count: ', train_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add normal and rest of pneumonia cases from https://www.kaggle.com/c/rsna-pneumonia-detection-challenge\n",
    "csv_normal = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname), nrows=None)\n",
    "csv_pneu = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname2), nrows=None)\n",
    "\n",
    "# One entry per patient, in first-seen order. stage_2_train_labels.csv has one row per\n",
    "# annotated box, so a pneumonia patientId can repeat; without drop_duplicates() such a\n",
    "# patient would be converted, listed and counted once per row.\n",
    "is_normal = csv_normal['class'] == 'Normal'\n",
    "is_pneumonia = csv_pneu['Target'].astype(int) == 1\n",
    "patients = {\n",
    "    'normal': csv_normal.loc[is_normal, 'patientId'].drop_duplicates().tolist(),\n",
    "    'pneumonia': csv_pneu.loc[is_pneumonia, 'patientId'].drop_duplicates().tolist(),\n",
    "}\n",
    "\n",
    "# cv2.imwrite() does not create missing folders, so make sure both targets exist\n",
    "for subset in ('train', 'test'):\n",
    "    os.makedirs(os.path.join(savepath, subset), exist_ok=True)\n",
    "\n",
    "for key in patients.keys():\n",
    "    arr = patients[key]\n",
    "    if len(arr) == 0:\n",
    "        continue\n",
    "    # split by patients: the test ids are fixed, download the .npy files from the repo.\n",
    "    # They were originally drawn and stored with:\n",
    "    #   num_diff_patients = len(np.unique(arr))\n",
    "    #   num_test = max(1, round(split*num_diff_patients))\n",
    "    #   test_patients = random.sample(list(arr), num_test)\n",
    "    #   np.save('rsna_test_patients_{}.npy'.format(key), np.array(test_patients))\n",
    "    # A set makes the membership test below a hash lookup instead of an array scan.\n",
    "    test_patients = set(np.load('rsna_test_patients_{}.npy'.format(key)).ravel().tolist())\n",
    "    for patient in arr:\n",
    "        # read the DICOM file and re-save its pixel data as <patientId>.png\n",
    "        ds = dicom.dcmread(os.path.join(kaggle_datapath, kaggle_imgpath, patient + '.dcm'))\n",
    "        pixel_array_numpy = ds.pixel_array\n",
    "        imgname = patient + '.png'\n",
    "        if patient in test_patients:\n",
    "            subset, samples, counts = 'test', test, test_count\n",
    "        else:\n",
    "            subset, samples, counts = 'train', train, train_count\n",
    "        outpath = os.path.join(savepath, subset, imgname)\n",
    "        # imwrite() reports failure through its return value instead of raising\n",
    "        if not cv2.imwrite(outpath, pixel_array_numpy):\n",
    "            raise IOError('could not write image: ' + outpath)\n",
    "        samples.append([patient, imgname, key])\n",
    "        counts[key] += 1\n",
    "\n",
    "print('test count: ', test_count)\n",
    "print('train count: ', train_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# final stats: per-label counters and the overall size of both splits\n",
    "print('Final stats')\n",
    "for caption, value in (\n",
    "    ('Train count: ', train_count),\n",
    "    ('Test count: ', test_count),\n",
    "    ('Total length of train: ', len(train)),\n",
    "    ('Total length of test: ', len(test)),\n",
    "):\n",
    "    print(caption, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run this cell when adding in new covid data from covid-chestxray-dataset\n",
    "# NOTE: everything below is wrapped in a triple-quoted string, so none of it executes\n",
    "# as written; remove the surrounding ''' quotes to enable it.\n",
    "\n",
    "# load in current train/test information\n",
    "'''train_filepath = 'train_split_v2.txt'\n",
    "test_filepath = 'test_split_v2.txt'\n",
    "file = open(train_filepath, 'r') \n",
    "trainfiles = file.readlines() \n",
    "trainfiles = np.array([line.split() for line in trainfiles])\n",
    "file = open(test_filepath, 'r')\n",
    "testfiles = file.readlines()\n",
    "testfiles = np.array([line.split() for line in testfiles])\n",
    "\n",
    "# find the new entries in csv \n",
    "new_entries = []\n",
    "for key in filename_label.keys():\n",
    "    arr = np.array(filename_label[key])\n",
    "    if arr.size == 0:\n",
    "        continue\n",
    "    for patient in arr:\n",
    "        if patient[1] not in trainfiles and patient[1] not in testfiles:\n",
    "            # if key is normal, bacteria or viral add to train folder\n",
    "            if key in ['normal', 'pneumonia']:\n",
    "                copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'train', patient[1]))\n",
    "                train.append(patient)\n",
    "                train_count[patient[2]] += 1\n",
    "            else: \n",
    "                new_entries.append(patient)\n",
    "new_entries = np.array(new_entries)\n",
    "\n",
    "# 10% of new entries should go into in test\n",
    "if new_entries.size > 0:\n",
    "    num_diff_patients = len(np.unique(new_entries[:,0]))\n",
    "    num_test = max(1, round(split*num_diff_patients))\n",
    "\n",
    "    i = 0\n",
    "    used_i = []\n",
    "    # insert patients who are already in dataset into the respective train/test\n",
    "    for patient in new_entries:\n",
    "        if patient[0] in trainfiles:\n",
    "            copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'train', patient[1]))\n",
    "            train.append(patient)\n",
    "            train_count[patient[2]] += 1\n",
    "            used_i.append(i)\n",
    "        elif patient[0] in testfiles:\n",
    "            copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'test', patient[1]))\n",
    "            test.append(patient)\n",
    "            test_count[patient[2]] += 1\n",
    "            used_i.append(i)\n",
    "        i += 1\n",
    "    # delete patients who have already been inserted\n",
    "    new_entries = np.delete(new_entries, used_i, axis=0)\n",
    "\n",
    "    # select num_test number of random patients\n",
    "    test_patients = random.sample(list(new_entries[:,0]), num_test)\n",
    "    print('test patients: ', test_patients)\n",
    "    # add to respective train/test folders\n",
    "    for patient in new_entries:\n",
    "        if patient[0] in test_patients:\n",
    "            copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'test', patient[1]))\n",
    "            test.append(patient)\n",
    "            test_count[patient[2]] += 1\n",
    "        else:\n",
    "            copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, 'train', patient[1]))\n",
    "            train.append(patient)\n",
    "            train_count[patient[2]] += 1\n",
    "\n",
    "print('added test count: ', test_count)\n",
    "print('added train count: ', train_count)'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export the train and test splits to text files\n",
    "# one sample per line, formatted as: patientid, filename, label, separated by a space\n",
    "# NOTE(review): both files are opened in append mode, so re-running this cell writes the\n",
    "# same lines again - presumably intended for the incremental-update workflow; confirm.\n",
    "for split_path, samples in (('train_split_v2.txt', train), ('test_split_v2.txt', test)):\n",
    "    # 'with' closes the file even if a write fails part-way through\n",
    "    with open(split_path, 'a') as split_file:\n",
    "        for sample in samples:\n",
    "            split_file.write(str(sample[0]) + ' ' + sample[1] + ' ' + sample[2] + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (covid)",
   "language": "python",
   "name": "covid"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
